import torch

from src.gfn.gfn import GFlowNet, LogProbsTensor, LossTensor, CumulativeLogProbsTensor, MeanLossTensor
from src.utils.trajectories import Trajectories

class STBGFlowNet(GFlowNet):
    """
    GFlowNet trained with the subtrajectory balance (SubTB) loss.

    On top of the forward/backward models handled by ``GFlowNet``, this class owns a
    state-flow model ``logF_model`` (required to have a scalar output) whose values are
    used as the log-flows at the starting states of subtrajectories.
    """

    def __init__(self,
                 env,
                 config,
                 forward_model,
                 backward_model,
                 logF_model,
                 lamda: float = 0.9,
                 tied: bool = False,
                 ):
        """
        Args:
            env: Environment, forwarded to ``GFlowNet.__init__``.
            config: Configuration mapping; ``config["gfn"]["lr_schedule"]`` is used to
                build the learning-rate scheduler.
            forward_model: Forward model, forwarded to ``GFlowNet.__init__``.
            backward_model: Backward model, forwarded to ``GFlowNet.__init__``.
            logF_model: State-flow model; moved to ``self.device`` and required to have
                ``output_dim == 1``.
            lamda: Geometric weighting factor; subtrajectories of length ``l`` are
                weighted by ``lamda ** l``. Must be strictly positive.
            tied: Forwarded to ``_init_optimizer``.

        Raises:
            ValueError: If ``lamda`` is not strictly positive.
            AssertionError: If ``logF_model.output_dim != 1``.
        """
        # A non-positive lamda makes the total weight zero (NaN loss) or sign-alternating.
        if lamda <= 0:
            raise ValueError(f"lamda must be strictly positive, got {lamda}.")

        super().__init__(env, config, forward_model, backward_model)
        self.logF_model = logF_model.to(self.device)
        assert self.logF_model.output_dim == 1, "LogF model must output a scalar."

        self.optimizer = self._init_optimizer(tied, include_logF=True)
        self.scheduler = self._init_scheduler(config["gfn"]["lr_schedule"])
        self.lamda = lamda

        # Total weight sum((n + 1 - i) * lamda**i, i = 1..n) with n = trajectory_length.
        # Summed directly instead of using the closed form
        # lamda / (lamda - 1)**2 * (...), which divides by zero when lamda == 1.
        self.lamda_normalisation = sum(self._subtraj_weights(self.trajectory_length))

    def _subtraj_weights(self, length: int) -> list:
        """
        Return the per-length weights for a trajectory of ``length`` steps.

        Element ``l - 1`` (for ``l = 1..length``) is ``(length + 1 - l) * lamda ** l``:
        the geometric factor for subtrajectories of length ``l`` times the number of
        such subtrajectories in the trajectory.
        """
        return [(length + 1 - l) * self.lamda ** l for l in range(1, length + 1)]

    def _compute_loss_precursors(self, trajs: Trajectories, head=None):
        """
        Compute the quantities the loss needs by calling ``trajs.compute_logPF``
        (with ``head``), ``trajs.compute_logPB`` and ``trajs.compute_logF``.

        Presumably these populate ``trajs.log_fullPF``, ``trajs.log_fullPB`` and
        ``trajs.logF``, which ``_get_mean_subtraj_loss`` reads (TODO confirm in
        ``Trajectories``).
        """
        trajs.compute_logPF(self, head)
        trajs.compute_logPB(self)
        trajs.compute_logF(self)

    def _cumulative_logprobs(self, logprobs: LogProbsTensor) -> CumulativeLogProbsTensor:
        """
        Given log probabilities, compute the cumulative log probabilities.

        Input:
            logprobs - torch.tensor of size (batch_size, trajectory_length)
        Output:
            cumulative_logprobs - torch.tensor of size (batch_size, trajectory_length + 1), where the j-th
            element of the i-th row is the sum of the first j elements of the i-th row of the input tensor.
            A j=0 column of zeros is prepended.
        """
        # Size the zero column from the input itself (not ``self.batch_size``) so that batches of
        # any size are supported; ``new_zeros`` also matches the input's dtype and device.
        zeros = logprobs.new_zeros((logprobs.shape[0], 1))
        return torch.cat((zeros, torch.cumsum(logprobs, dim=1)), dim=1)

    def _get_mean_subtraj_loss(self, trajs: Trajectories) -> MeanLossTensor:
        """
        Get the mean loss across all subtrajectories of the same length.

        Output:
            mean_loss - torch.tensor of size (trajectory_length), whose (l-1)-th element is the mean
            squared subtrajectory-balance score over all sub-trajectories of length l in the batch.
        """
        # B = batch_size, L = trajectory_length, l = subtrajectory length,
        # N = L + 1 - l = number of subtrajectories of length l in a trajectory of length L.
        L = trajs.length

        logPF_cum = self._cumulative_logprobs(trajs.log_fullPF)  # shape (B, L + 1)
        logPB_cum = self._cumulative_logprobs(trajs.log_fullPB)  # shape (B, L + 1)

        # Log-flows at state indices 0..L: the logF estimates (assumed shape (B, L); index 0 is the
        # logZ estimate) followed by the clipped log-reward used as the terminal flow. Built once,
        # outside the loop, so start and end flows are plain slices of the same tensor.
        log_terminal_flow = trajs.log_rewards.clamp_min(self.log_reward_clip_min).unsqueeze(1)
        logF_all = torch.cat((trajs.logF, log_terminal_flow), dim=1)  # shape (B, L + 1)

        mean_losses = logPF_cum.new_zeros(L)

        for l in range(1, L + 1):
            n_sub = L + 1 - l

            # Subtrajectory m -> m + l for m = 0..N-1: flows at both ends and summed log-probs.
            logF_start = logF_all[:, :n_sub]                               # shape (B, N)
            logF_end = logF_all[:, l:]                                     # shape (B, N)
            logPF_sum = logPF_cum[:, l:] - logPF_cum[:, :n_sub]            # shape (B, N)
            logPB_sum = logPB_cum[:, l:] - logPB_cum[:, :n_sub]            # shape (B, N)

            score = logF_start + logPF_sum - logPB_sum - logF_end

            # Mean squared score across all subtrajectories of this length, across the entire batch.
            mean_losses[l - 1] = score.pow(2).mean()

        return mean_losses

    def loss(self, trajs: Trajectories, head=None) -> LossTensor:
        """
        Subtrajectory balance loss using the "geometric_within" weighting of the original paper.

        The mean squared score of subtrajectories of length l is weighted by lamda ** l times the
        number (L + 1 - l) of such subtrajectories per trajectory, and the weighted sum is divided
        by the total weight.
        """
        self._compute_loss_precursors(trajs, head)

        mean_subtraj_losses = self._get_mean_subtraj_loss(trajs)  # shape (trajectory_length)

        # Weights are derived from the length of the computed losses (rather than the configured
        # ``self.trajectory_length``) so their shapes always agree. When the lengths match, the
        # normaliser equals ``self.lamda_normalisation``.
        weights = self._subtraj_weights(mean_subtraj_losses.shape[0])
        weights_t = mean_subtraj_losses.new_tensor(weights)

        loss = torch.sum(mean_subtraj_losses * weights_t) / sum(weights)

        return loss
    

